#  Version: 2020.02.21
#
#  MIT License
#
#  Copyright (c) 2018 Jiankang Deng and Jia Guo
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import datetime
import os
import sys
import time

import mxnet as mx
import numpy as np
from mxnet import ndarray as nd

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
import face_image


def ch_dev(arg_params, aux_params, ctx):
    """Move every parameter array onto the context ``ctx``.

    Each value of ``arg_params`` and ``aux_params`` is replaced by the result of
    its ``as_in_context(ctx)`` call. The input dicts are only read, never
    modified.

    Returns:
        A tuple ``(new_args, new_auxs)`` of freshly built dicts with the same
        keys as ``arg_params`` and ``aux_params`` respectively.
    """
    def _relocate(params):
        # Build a new mapping so the caller's dict is left as-is.
        return {name: array.as_in_context(ctx) for name, array in params.items()}

    return _relocate(arg_params), _relocate(aux_params)


def _load_model(prefix, epoch, ctx, batch_size, image_size):
    """Load checkpoint ``prefix``/``epoch`` and bind its ``fc1_output`` sub-graph.

    Args:
        prefix: checkpoint prefix passed to ``mx.model.load_checkpoint``.
        epoch: checkpoint epoch number.
        ctx: MXNet context that the parameters and the module are placed on.
        batch_size: batch dimension of the bound ``data`` input.
        image_size: ``image_size[0]`` and ``image_size[1]`` become the last two
            axes of the bound ``data`` shape ``(batch_size, 3, s0, s1)``.

    Returns:
        A bound ``mx.mod.Module`` with its parameters set.
    """
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
    # Keep only the sub-graph ending at fc1_output. This presumably drops the
    # training loss head, so the module needs no label input.
    sym = sym.get_internals()['fc1_output']
    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    # Inference-only bind: no gradient buffers are needed for a benchmark.
    model.bind(data_shapes=[('data', (batch_size, 3, image_size[0], image_size[1]))],
               for_training=False)
    model.set_params(arg_params, aux_params)
    return model


def _time_forward_passes(model, imgrec, data_shape, num_records, max_samples):
    """Time one forward pass per record and return the latencies in seconds.

    Records ``1 .. num_records - 1`` of ``imgrec`` are decoded one at a time,
    copied into a single reusable input buffer of shape ``data_shape`` and fed
    through ``model``. Stops early once ``max_samples`` latencies have been
    collected.
    """
    data = nd.zeros(data_shape)
    stat = []
    for idx in range(1, num_records):
        if len(stat) % 100 == 0:
            print('processing', len(stat))
        _, img = mx.recordio.unpack(imgrec.read_idx(idx))
        img = mx.image.imdecode(img)
        data[0][:] = nd.transpose(img, axes=(2, 0, 1))  # HWC -> CHW
        # MXNet ops run asynchronously. Wait for decoding and the copy into
        # `data` to finish so the timer below covers only the forward pass.
        data.wait_to_read()
        start = time.perf_counter()
        model.forward(mx.io.DataBatch(data=(data,)), is_train=False)
        # asnumpy() blocks until the forward pass has actually completed.
        model.get_outputs()[0].asnumpy()
        stat.append(time.perf_counter() - start)
        if len(stat) == max_samples:
            break
    return stat


def main(args):
    """Benchmark single-image forward latency of a checkpointed model.

    Loads the checkpoint named by ``args.model`` (formatted "prefix,epoch") and
    truncates it at ``fc1_output``. It then runs one forward pass per image
    record of ``<args.data>/train.rec`` until ``args.param1`` timings are
    collected. Finally it prints the mean latency in seconds, excluding the
    first ``warmup`` iterations.

    Args:
        args: parsed namespace with ``gpu`` (negative selects CPU), ``data``,
            ``model``, ``batch_size`` and ``param1``. ``args.ctx_num`` is set
            to 1.

    Raises:
        ValueError: if record 0 of ``train.rec`` is not a header record.
    """
    warmup = 10  # leading iterations excluded from the reported average
    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()
    args.ctx_num = 1
    prop = face_image.load_property(args.data)
    image_size = prop.image_size
    print('image_size', image_size)
    vec = args.model.split(',')
    prefix = vec[0]
    epoch = int(vec[1])
    print('loading', prefix, epoch)
    model = _load_model(prefix, epoch, ctx, args.batch_size, image_size)

    path_imgrec = os.path.join(args.data, 'train.rec')
    path_imgidx = os.path.join(args.data, 'train.idx')
    imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
    try:
        header, _ = mx.recordio.unpack(imgrec.read_idx(0))
        # Record 0 must be a header. Its label[0] is one past the last image
        # record index, as used for the iteration range below.
        if header.flag <= 0:
            raise ValueError('%s: record 0 is not a header record (flag=%d)'
                             % (path_imgrec, header.flag))
        print('header0 label', header.label)
        stat = _time_forward_passes(model, imgrec,
                                    (1, 3, image_size[0], image_size[1]),
                                    int(header.label[0]), args.param1)
    finally:
        imgrec.close()

    timed = stat[warmup:]
    if not timed:
        # Avoid np.mean([]) -> nan plus a RuntimeWarning.
        print('not enough samples to time: collected %d, first %d are warm-up'
              % (len(stat), warmup))
        return
    print('avg infer time', np.mean(timed))


if __name__ == '__main__':
    # Command-line entry point: parse benchmark options and run main().
    parser = argparse.ArgumentParser(description='do network benchmark')
    # general
    parser.add_argument('--gpu', default=0, type=int,
                        help='index of the GPU device to run the model on')
    parser.add_argument('--data', default='', type=str,
                        help='dataset directory containing train.rec, train.idx '
                             'and the property metadata read by face_image')
    parser.add_argument('--model', default='../model/softmax,50',
                        help='checkpoint to load, formatted as "prefix,epoch"')
    parser.add_argument('--batch-size', default=1, type=int,
                        help='batch dimension used when binding the module '
                             '(the benchmark feeds one image per forward pass)')
    parser.add_argument('--param1', default=1010, type=int,
                        help='number of forward passes to time, including the '
                             'warm-up iterations excluded from the average')
    args = parser.parse_args()
    main(args)
